{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dd0823d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 加权信息 : tensor([[[ 0.2412, -0.0175,  0.2173, -0.5284],\n",
      "         [ 0.2514, -0.0224,  0.2202, -0.5313],\n",
      "         [ 0.2364, -0.0168,  0.2185, -0.5227]],\n",
      "\n",
      "        [[-0.0085,  0.0158,  0.0018, -0.3527],\n",
      "         [-0.1217,  0.0532, -0.0055, -0.2765],\n",
      "         [-0.0809,  0.0917, -0.0446, -0.4277]]], grad_fn=<ViewBackward0>)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "# 一个形状为 (batch_size, seq_len, feature_dim) 的张量 x\n",
    "x = torch.randn(2, 3, 4)  # 形状 (batch_size, seq_len, feature_dim) \n",
    "# 定义头数和每个头的维度\n",
    "num_heads = 2\n",
    "head_dim = 2\n",
    "# feature_dim 必须是 num_heads * head_dim 的整数倍\n",
    "assert x.size(-1) == num_heads * head_dim\n",
    "# 定义线性层用于将 x 转换为 Q, K, V 向量\n",
    "linear_q = torch.nn.Linear(4, 4)\n",
    "linear_k = torch.nn.Linear(4, 4)\n",
    "linear_v = torch.nn.Linear(4, 4)\n",
    "# 通过线性层计算 Q, K, V\n",
    "Q = linear_q(x)  # 形状 (batch_size, seq_len, feature_dim) \n",
    "K = linear_k(x)  # 形状 (batch_size, seq_len, feature_dim) \n",
    "V = linear_v(x)  # 形状 (batch_size, seq_len, feature_dim) \n",
    "# 将 Q, K, V 分割成 num_heads 个头\n",
    "def split_heads(tensor, num_heads):\n",
    "    batch_size, seq_len, feature_dim = tensor.size()\n",
    "    head_dim = feature_dim // num_heads\n",
    "    output = tensor.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)\n",
    "    return  output # 形状 (batch_size, num_heads, seq_len, feature_dim)\n",
    "Q = split_heads(Q, num_heads)  # 形状 (batch_size, num_heads, seq_len, head_dim)\n",
    "K = split_heads(K, num_heads)  # 形状 (batch_size, num_heads, seq_len, head_dim)\n",
    "V = split_heads(V, num_heads)  # 形状 (batch_size, num_heads, seq_len, head_dim)\n",
    "# 计算 Q 和 K 的点积，作为相似度分数 , 也就是自注意力原始权重\n",
    "raw_weights = torch.matmul(Q, K.transpose(-2, -1))  # 形状 (batch_size, num_heads, seq_len, seq_len)\n",
    "# 对自注意力原始权重进行缩放\n",
    "scale_factor = K.size(-1) ** 0.5\n",
    "scaled_weights = raw_weights / scale_factor  # 形状 (batch_size, num_heads, seq_len, seq_len)\n",
    "# 对缩放后的权重进行 softmax 归一化，得到注意力权重\n",
    "attn_weights = F.softmax(scaled_weights, dim=-1)  # 形状 (batch_size, num_heads, seq_len, seq_len)\n",
    "# 将注意力权重应用于 V 向量，计算加权和，得到加权信息\n",
    "attn_outputs = torch.matmul(attn_weights, V)  # 形状 (batch_size, num_heads, seq_len, head_dim)\n",
    "# 将所有头的结果拼接起来\n",
    "def combine_heads(tensor):\n",
    "    batch_size, num_heads, seq_len, head_dim = tensor.size()\n",
    "    feature_dim = num_heads * head_dim\n",
    "    output = tensor.transpose(1, 2).contiguous().view(batch_size, seq_len, feature_dim)\n",
    "    return output# 形状 : (batch_size, seq_len, feature_dim)\n",
    "attn_outputs = combine_heads(attn_outputs)  # 形状 (batch_size, seq_len, feature_dim) \n",
    "# 对拼接后的结果进行线性变换\n",
    "linear_out = torch.nn.Linear(4, 4)\n",
    "attn_outputs = linear_out(attn_outputs)  # 形状 (batch_size, seq_len, feature_dim) \n",
    "print(\" 加权信息 :\", attn_outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "219601c1",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
